{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d5195177",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "def dropout_layer (X,dropout):   #X为dropout层的输入，dropout为设置的丢弃概率\n",
    "    assert 0<=dropout<=1        #丢弃概率介于0，1之间\n",
    "    if dropout == 1:\n",
    "       return torch.zeros_like(X) #若丢弃概率为1，则X的全部项均被置0\n",
    "    if dropout == 0:\n",
    "       return X                   #若丢弃概率为0，不对X作丢弃操作，直接返回X\n",
    "    mask=(torch.Tensor(X.shape).uniform_(0,1)>dropout).float() #用uniform函数生成0-1间的随机实数，利用”>\"，将大于dropout的记为1，小于dropout的记为0，实现丢弃操作\n",
    "    return mask*X/(1-dropout) #将mask与X相乘实现丢弃操作，并除以(1-dropout)，这里不使用选中X中元素置0的原因是相乘操作相比选中操作更快"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8ff5ba10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],\n",
      "        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])\n",
      "tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],\n",
      "        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])\n",
      "tensor([[ 0.,  0.,  4.,  0.,  0., 10.,  0., 14.],\n",
      "        [16., 18., 20.,  0.,  0., 26.,  0., 30.]])\n",
      "tensor([[0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "#丢弃法测试\n",
    "X=torch.arange(16,dtype=torch.float32).reshape((2,8))\n",
    "print(X)\n",
    "print(dropout_layer (X,0.))  #丢弃率设置为0\n",
    "print(dropout_layer (X,0.5)) #丢弃率设置为0.5\n",
    "print(dropout_layer (X,1))   #丢弃率设置为1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36torch040]",
   "language": "python",
   "name": "conda-env-py36torch040-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
